import torch
import torch.distributions
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
from utils.datasets.paths import get_svhn_path
from utils.datasets.svhn_augmentation import get_SVHN_augmentation
from utils.datasets.combo_dataset import ComboDataset
import os.path

# Batch sizes used by the loader factories below when batch_size is None.
DEFAULT_TRAIN_BATCHSIZE = 128
DEFAULT_TEST_BATCHSIZE = 128

# Three per-channel mean values of SVHN on the [0, 1] scale.
_SVHN_MEAN_FLOAT = (0.4377, 0.4438, 0.4728)
# The same means on the 0-255 scale, truncated towards zero by int().
SVHN_mean_int = tuple(int(255 * channel_mean) for channel_mean in _SVHN_MEAN_FLOAT)
SVHN_mean = torch.tensor(_SVHN_MEAN_FLOAT)

# Index files written by generate_SVHN_val_extra_split and read back by
# SVHNValidationExtraSplit; relative paths, i.e. resolved against the CWD.
val_idcs_file = 'svhn_validation_split.pth'
extra_idcs_file = 'svhn_extra_split.pth'

def generate_SVHN_val_extra_split(val_per_class=1000):
    """Create and save a class-balanced validation/extra split of SVHN ``extra``.

    For each of the 10 classes, ``val_per_class`` samples of the SVHN ``extra``
    split are drawn at random (via ``torch.randperm``) into the validation
    split; all remaining samples of that class go to the new extra split.
    The two index tensors are written with ``torch.save`` to
    ``val_idcs_file`` and ``extra_idcs_file``.

    Args:
        val_per_class: number of validation samples taken from each class.
            Every class must have strictly more samples than this.

    Raises:
        AssertionError: if a class has too few samples, or if the sanity
            check on the generated split fails.
    """
    svhn_path = get_svhn_path()
    # Only the labels are read here, so no image transform is needed.
    svhn_extra = datasets.SVHN(svhn_path, split='extra', transform=None, download=True)
    num_classes = 10
    labels = np.asarray(svhn_extra.labels)
    labels_t = torch.as_tensor(labels)

    per_class_idcs = []
    per_class_counts = []
    for class_idx in range(num_classes):
        # view(-1) keeps a 1-D tensor even when a class has a single sample.
        idcs = torch.nonzero(labels_t == class_idx, as_tuple=False).view(-1)
        count = len(idcs)
        assert count > val_per_class
        per_class_idcs.append(idcs)
        per_class_counts.append(count)
        print(f'Extra Samples per class {count} - Validation {val_per_class} - New Extra {count - val_per_class}')

    num_val = num_classes * val_per_class
    # Size from the per-class counts so every slot below gets filled.
    num_extra = sum(per_class_counts) - num_val

    val_idcs = torch.zeros(num_val, dtype=torch.long)
    extra_idcs = torch.zeros(num_extra, dtype=torch.long)

    val_pos = 0
    extra_pos = 0
    for idcs, count in zip(per_class_idcs, per_class_counts):
        num_class_extra = count - val_per_class
        # Random permutation of this class's samples: the first
        # val_per_class go to validation, the rest to the extra split.
        perm = torch.randperm(count)
        val_idcs[val_pos:val_pos + val_per_class] = idcs[perm[:val_per_class]]
        extra_idcs[extra_pos:extra_pos + num_class_extra] = idcs[perm[val_per_class:]]
        val_pos += val_per_class
        extra_pos += num_class_extra

    # Validate: every class has exactly the expected number of samples in each split.
    validation_labels = labels[val_idcs.numpy()]
    extra_labels = labels[extra_idcs.numpy()]

    for class_idx, count in enumerate(per_class_counts):
        assert np.sum(extra_labels == class_idx) == count - val_per_class
        assert np.sum(validation_labels == class_idx) == val_per_class

    torch.save(val_idcs, val_idcs_file)
    torch.save(extra_idcs, extra_idcs_file)
    print('Split generation completed')


class SVHNValidationExtraSplit(Dataset):
    """Subset of the SVHN ``extra`` split selected by a saved index tensor.

    ``split='validation-split'`` uses the indices stored in ``val_idcs_file``;
    ``split='extra-split'`` uses those in ``extra_idcs_file``. Items are
    fetched from a torchvision SVHN ``extra`` dataset, which applies
    ``transform`` / ``target_transform``.
    """

    def __init__(self, path, split, transform=None, target_transform=None):
        """
        Args:
            path: root directory passed to ``torchvision.datasets.SVHN``.
            split: ``'extra-split'`` or ``'validation-split'``.
            transform: forwarded to the SVHN dataset.
            target_transform: forwarded to the SVHN dataset.

        Raises:
            ValueError: if ``split`` is not one of the supported names.
        """
        # Validate the split name before constructing (and possibly
        # downloading) the underlying dataset.
        if split == 'extra-split':
            idcs_file, description = extra_idcs_file, 'SVHN Extra split'
        elif split == 'validation-split':
            idcs_file, description = val_idcs_file, 'SVHN Validation split'
        else:
            raise ValueError(f"Unknown split {split!r}; expected 'extra-split' or 'validation-split'")

        self.svhn_extra = datasets.SVHN(path, split='extra', transform=transform, target_transform=target_transform, download=True)
        self.idcs = torch.load(idcs_file)
        print(f'{description} - Length {len(self.idcs)}')

        self.length = len(self.idcs)

    def __getitem__(self, index):
        # Translate the subset position into an index of the full 'extra' set.
        extra_idx = int(self.idcs[index])
        return self.svhn_extra[extra_idx]

    def __len__(self):
        return self.length

class SVHNTrainPlusExtraSplit(ComboDataset):
    """SVHN ``train`` split combined with the ``extra-split`` subset.

    Both parts receive the same ``transform`` / ``target_transform`` and are
    handed to ``ComboDataset`` in the order [train, extra-split].
    """

    def __init__(self, path, transform=None, target_transform=None):
        shared_kwargs = dict(transform=transform, target_transform=target_transform)
        train_part = datasets.SVHN(path, split='train', **shared_kwargs)
        extra_part = SVHNValidationExtraSplit(path, 'extra-split', **shared_kwargs)
        super().__init__([train_part, extra_part])

def get_SVHNTrainPlusExtra(shuffle=True, batch_size=None, augm_type='none', num_workers=4, config_dict=None):
    """Build a DataLoader over SVHN ``train`` plus the ``extra-split`` subset.

    Args:
        shuffle: whether the DataLoader shuffles samples.
        batch_size: batch size; ``None`` selects ``DEFAULT_TRAIN_BATCHSIZE``.
        augm_type: augmentation name passed to ``get_SVHN_augmentation``.
        num_workers: number of DataLoader worker processes.
        config_dict: optional dict; when given, it is filled with a
            description of this loader's configuration.

    Returns:
        A ``torch.utils.data.DataLoader`` over ``SVHNTrainPlusExtraSplit``.
    """
    if batch_size is None:
        # This loader serves the training splits, so use the training default
        # (consistent with get_SVHNValidationExtraSplit for 'extra-split').
        batch_size = DEFAULT_TRAIN_BATCHSIZE

    augm_config = {}
    transform = get_SVHN_augmentation(augm_type, config_dict=augm_config)

    path = get_svhn_path()
    dataset = SVHNTrainPlusExtraSplit(path, transform)

    loader = DataLoader(dataset, batch_size=batch_size,
                        shuffle=shuffle, num_workers=num_workers)

    if config_dict is not None:
        config_dict['Dataset'] = 'SVHN Train + Extra Split'
        config_dict['Shuffle'] = shuffle
        # NOTE(review): 'Batch out_size' looks like it was meant to be
        # 'Batch size'; kept unchanged since readers of config_dict may rely on it.
        config_dict['Batch out_size'] = batch_size
        config_dict['Augmentation'] = augm_config

    return loader


def get_SVHNValidationExtraSplit(split='validation-split', shuffle=None, batch_size=None, augm_type='none', num_workers=4):
    """Build a DataLoader over one of the custom SVHN ``extra`` subsets.

    Args:
        split: ``'validation-split'`` or ``'extra-split'``.
        shuffle: whether to shuffle; ``None`` means shuffle only for
            ``'extra-split'``.
        batch_size: batch size; ``None`` selects ``DEFAULT_TRAIN_BATCHSIZE``
            for ``'extra-split'`` and ``DEFAULT_TEST_BATCHSIZE`` otherwise.
        augm_type: augmentation name passed to ``get_SVHN_augmentation``.
        num_workers: number of DataLoader worker processes.

    Returns:
        A ``torch.utils.data.DataLoader`` over ``SVHNValidationExtraSplit``.

    Raises:
        ValueError: propagated from ``SVHNValidationExtraSplit`` for an
            unknown ``split``.
    """
    # 'extra-split' is treated as training data: training batch size and shuffling.
    is_train_split = split == 'extra-split'

    if batch_size is None:
        batch_size = DEFAULT_TRAIN_BATCHSIZE if is_train_split else DEFAULT_TEST_BATCHSIZE

    if shuffle is None:
        shuffle = is_train_split

    transform = get_SVHN_augmentation(augm_type)

    path = get_svhn_path()
    dataset = SVHNValidationExtraSplit(path, split, transform)

    loader = DataLoader(dataset, batch_size=batch_size,
                        shuffle=shuffle, num_workers=num_workers)

    return loader

def replace_faulty_validation_samples(val_idcs, extra_idcs, labels, faulty):
    """Swap selected validation samples for random same-label extra samples.

    Args:
        val_idcs: 1-D index tensor of the validation split.
        extra_idcs: 1-D index tensor of the extra split.
        labels: 1-D tensor holding the label of every dataset sample.
        faulty: iterable of positions into ``val_idcs`` whose samples
            should be replaced.

    Returns:
        ``(new_val, new_extra)``: copies of the inputs. For every ``i`` in
        ``faulty``, ``new_val[i]`` becomes a randomly chosen extra sample with
        the same label as ``val_idcs[i]``, and that sample's slot in
        ``new_extra`` receives ``val_idcs[i]``. Split sizes and per-class
        counts are unchanged; the input tensors are not modified.

    Raises:
        RuntimeError: if no same-label replacement candidate is left.
    """
    new_val = val_idcs.clone()
    new_extra = extra_idcs.clone()

    for i in faulty:
        label_i = labels[val_idcs[i]]

        # Candidates: currently in the extra split, never part of the original
        # validation split, not already chosen as a replacement, same label.
        available = torch.zeros(len(labels), dtype=torch.bool)
        available[new_extra] = True
        available[val_idcs] = False
        available[new_val] = False
        available[labels != label_i] = False

        # view(-1) keeps this 1-D even when exactly one candidate remains.
        candidates = torch.nonzero(available, as_tuple=False).view(-1)
        if len(candidates) == 0:
            raise RuntimeError(f'No replacement candidate with label {int(label_i)} for validation position {i}')
        replacement = candidates[torch.randint(len(candidates), (1,))].item()

        assert labels[replacement] == label_i

        new_val[i] = replacement

        # Move the replaced validation sample into the slot freed in the extra split.
        slot = torch.nonzero(new_extra == replacement, as_tuple=False).view(-1)
        new_extra[slot] = val_idcs[i]

    return new_val, new_extra


if __name__ == "__main__":
    # Positions into the saved validation index tensor whose samples should
    # be swapped out; fill in before running.
    faulty = [ ]

    val_idcs = torch.load(val_idcs_file)
    extra_idcs = torch.load(extra_idcs_file)

    path = get_svhn_path()
    svhn_extra = datasets.SVHN(path, split='extra', transform=None, target_transform=None)
    labels = torch.as_tensor(svhn_extra.labels, dtype=torch.long)

    new_val, new_extra = replace_faulty_validation_samples(val_idcs, extra_idcs, labels, faulty)

    # The swap must keep the per-class size of both splits ...
    for i in range(10):
        assert torch.sum(labels[val_idcs] == i) == torch.sum(labels[new_val] == i)
        assert torch.sum(labels[extra_idcs] == i) == torch.sum(labels[new_extra] == i)
    # ... and must never leave a sample in both splits.
    assert not set(new_val.tolist()) & set(new_extra.tolist())

    torch.save(new_val, val_idcs_file)
    torch.save(new_extra, extra_idcs_file)
    print('Labels replaced')